import sys
from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QHBoxLayout,
                             QLabel, QPushButton, QTextEdit,
                             QVBoxLayout, QSlider, QDesktopWidget, QMainWindow)

from PyQt5.QtCore import QTimer, QTime, QCoreApplication, Qt
from PyQt5.QtGui import QFont


from davisinteractive.metrics import batched_jaccard, batched_f_measure
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas

import matplotlib.pyplot as plt
import numpy as np
import time
import os
import csv
import random
from datetime import datetime

from davisinteractive.utils.visualization import overlay_mask, _pascal_color_map
from libs import utils_custom
from PIL import Image, ImageFont, ImageDraw

class App(QWidget):
    """Interactive Qt demo for scribble-based multi-object video segmentation.

    Workflow per round: the user draws scribbles on the current frame (left
    mouse button = currently selected object, any other button = background),
    presses "Run VOS" to propagate the model's masks over the whole video, and
    then picks one of up to four suggested candidate frames (A-D) to annotate in
    the next round. Timestamps of scribbling, propagation and candidate picking
    are collected and, on "Satisfied", written together with per-object J/F
    scores via ``DIE.write_in_csv``.

    Args:
        DIE: experiment helper; this class reads ``DIE.videos``,
            ``DIE.algorithm_name`` and ``DIE.current_time`` and calls
            ``DIE.write_in_csv``.
        model: segmentation model exposing ``init_with_new_video``,
            ``Run_interaction``, ``Run_propagation``, ``Get_mask``,
            ``Get_mask_index`` and a per-frame ``scores_nf`` sequence (frames
            with the lowest scores are suggested first).
        root: dataset root containing ``JPEGImages/480p/<seq>`` and
            ``Annotations/480p/<seq>``.
        video_indices: index into ``DIE.videos`` to open, or ``None`` to open
            the first video. When not ``None`` the application quits after
            "Satisfied".
        save_imgs: if True, final masks are saved as palette PNGs under
            ``result_video/``.
    """

    def __init__(self, DIE, model, root, video_indices, save_imgs=False):
        super().__init__()
        self.DIE = DIE
        self.model = model
        self.root = root
        sequence_list = DIE.videos
        self.sequence_list = sequence_list
        self.video_indices = video_indices
        self.video_idx = 0 if video_indices is None else video_indices
        self.sequence = sequence_list[self.video_idx]
        print(str(self.video_idx) + self.sequence)
        self.frames = utils_custom.load_frames(os.path.join(root, 'JPEGImages', '480p', self.sequence))  # f h w 3
        self.num_frames, self.height, self.width = self.frames.shape[:3]
        self.vis_frames = self.frames.copy()
        self.gts_overlayed = self.frames.copy()
        self.gts = utils_custom.load_gts_multi(os.path.join(root, 'Annotations', '480p', self.sequence))
        for fr in range(self.num_frames):
            self.gts_overlayed[fr] = overlay_mask(self.frames[fr], self.gts[fr], alpha=0.4, contour_thickness=2)
        self.cmap = _pascal_color_map()
        # Largest label id in the ground truth; assumes object ids are 1..n_obj.
        self.n_obj = self.gts.max()
        with Image.open('etc/00000.png') as palette_img:
            self._palette = palette_img.getpalette()

        # font declare
        font_helveltica = 'etc/fonts/helvetica.ttf'
        self.selected_font = ImageFont.truetype(font_helveltica, size=20)

        # init model
        self.model.init_with_new_video(self.frames, self.n_obj)
        self.current_object = 1

        # Round / timing state
        self.first_scr = None
        self.current_round = 0
        self.scribble_timesteps = []
        self.operate_timesteps = []
        self.finding_timesteps = []
        self.VOS_once_executed_bool = False
        self.not_started = True
        # False between a VOS run and the user's pick of a candidate frame.
        self.after_candidates_decided = True
        self.candidate_frames = []

        self.text_print = ''
        self.save_imgs = save_imgs

        # window settings
        self.setWindowTitle('Demo: CVPR2021_GIS-RAmap')
        self.setGeometry(100, 100, int(self.width*1.2)+800, (int(self.height*1.2)+200))
        qr = self.frameGeometry()
        cp = QDesktopWidget().availableGeometry().center()
        qr.moveCenter(cp)
        self.move(qr.topLeft())
        self.show()

        # object buttons
        self.obj1_button = self._make_object_button(1, 'red', self.obj1_pressed)
        self.obj2_button = self._make_object_button(2, 'green', self.obj2_pressed)
        self.obj3_button = self._make_object_button(3, 'yellow', self.obj3_pressed)
        self.obj4_button = self._make_object_button(4, 'blue', self.obj4_pressed)
        self.obj5_button = self._make_object_button(5, 'purple', self.obj5_pressed)

        # buttons
        self.prev_button = QPushButton('Prev [<-]')
        self.prev_button.clicked.connect(self.on_prev)
        self.prev_button.setShortcut(Qt.Key_Left)
        self.next_button = QPushButton('Next [->]')
        self.next_button.clicked.connect(self.on_next)
        self.next_button.setShortcut(Qt.Key_Right)
        self.play_button = QPushButton('Play [P]')
        self.play_button.clicked.connect(self.on_play)
        self.play_button.setShortcut('P')
        self.restart_button = QPushButton('Restart the video')
        self.restart_button.clicked.connect(self.restart_video)
        self.run_button = QPushButton('Run VOS [R]')
        # `pressed` fires before `clicked`, so the status text appears before the
        # (slow) propagation in on_run starts.
        self.run_button.pressed.connect(self.on_run_dschange)
        self.run_button.clicked.connect(self.on_run)
        self.run_button.setShortcut('R')
        self.end_button = QPushButton('Satisfied [S]')
        self.end_button.clicked.connect(self.on_end)
        self.end_button.setShortcut('S')

        self.cand1_button = QPushButton('Candidate A [A]')
        self.cand1_button.clicked.connect(self.on_candidateA)
        self.cand1_button.setShortcut('A')
        self.cand2_button = QPushButton('Candidate B [B]')
        self.cand2_button.clicked.connect(self.on_candidateB)
        self.cand2_button.setShortcut('B')
        self.cand3_button = QPushButton('Candidate C [C]')
        self.cand3_button.clicked.connect(self.on_candidateC)
        self.cand3_button.setShortcut('C')
        self.cand4_button = QPushButton('Candidate D [D]')
        self.cand4_button.clicked.connect(self.on_candidateD)
        self.cand4_button.setShortcut('D')

        # LCD: frame counter
        self.lcd1 = QTextEdit()
        self.lcd1.setReadOnly(True)
        self.lcd1.setMaximumHeight(28)
        self.lcd1.setMaximumWidth(100)
        self.lcd1.setText('{: 3d} / {: 3d}'.format(0, self.num_frames-1))

        # LCD#2: round counter
        self.lcd2 = QTextEdit()
        self.lcd2.setReadOnly(True)
        self.lcd2.setMaximumHeight(28)
        self.lcd2.setMaximumWidth(self.width)
        self.lcd2.setText('Current round : {:02d}'.format(self.current_round+1))

        # LCD#3: running log
        self.lcd3 = QTextEdit()
        self.lcd3.setReadOnly(True)
        self.lcd3.setMaximumHeight(600)
        self.lcd3.setMaximumWidth(600)
        self.text_print += 'Round [{:02d}]\n'.format(self.current_round+1)
        self.lcd3.setText(self.text_print)

        # slide
        self.slider = QSlider(Qt.Horizontal)
        self.slider.setMinimum(0)
        self.slider.setMaximum(self.num_frames-1)
        self.slider.setValue(0)
        self.slider.setTickPosition(QSlider.TicksBelow)
        self.slider.setTickInterval(1)
        self.slider.valueChanged.connect(self.slide)

        # main figure (annotation canvas)
        self.fig1 = plt.Figure()
        self.ax1 = plt.Axes(self.fig1, [0., 0., 1., 1.])
        self.ax1.set_axis_off()
        self.fig1.add_axes(self.ax1)
        self.canvas1 = FigureCanvas(self.fig1)

        self.cidpress = self.fig1.canvas.mpl_connect('button_press_event', self.on_press)
        self.cidrelease = self.fig1.canvas.mpl_connect('button_release_event', self.on_release)
        self.cidmotion = self.fig1.canvas.mpl_connect('motion_notify_event', self.on_motion)

        # sub figure (ground-truth view)
        self.fig2 = plt.Figure()
        self.ax2 = plt.Axes(self.fig2, [0., 0., 1., 1.])
        self.ax2.set_axis_off()
        self.fig2.add_axes(self.ax2)
        self.canvas2 = FigureCanvas(self.fig2)
        self.canvas2.setMaximumHeight(320)
        self.canvas2.setMaximumWidth(600)
        self.label1 = QLabel('Ground-Truth', self)
        self.label1.setAlignment(Qt.AlignCenter)
        font1 = self.label1.font()
        font1.setPointSize(20)
        # QLabel.font() returns a copy; it must be set back to take effect.
        self.label1.setFont(font1)

        # object buttons
        obj_buttons = QVBoxLayout()
        obj_buttons.addSpacing(20)
        for button in self._object_buttons():
            obj_buttons.addWidget(button)
        obj_buttons.addSpacing(20)

        # navigator for layout
        navi = QHBoxLayout()
        navi.addWidget(self.lcd1)
        navi.addWidget(self.prev_button)
        navi.addWidget(self.play_button)
        navi.addWidget(self.next_button)
        navi.addStretch(1)
        navi.addStretch(1)
        navi.addWidget(self.restart_button)
        navi.addWidget(self.run_button)
        navi.addWidget(self.end_button)

        navi2 = QHBoxLayout()
        navi2.addWidget(self.cand1_button)
        navi2.addWidget(self.cand2_button)
        navi3 = QHBoxLayout()
        navi3.addWidget(self.cand3_button)
        navi3.addWidget(self.cand4_button)

        # main layout
        layout_main = QVBoxLayout()
        layout_main.addWidget(self.canvas1)
        layout_main.addWidget(self.slider)
        layout_main.addWidget(self.lcd2)
        layout_main.addLayout(navi)
        layout_main.addLayout(navi2)
        layout_main.addLayout(navi3)
        layout_main.setStretchFactor(navi, 1)
        layout_main.setStretchFactor(self.canvas1, 0)

        # sub layout
        layout_sub = QVBoxLayout()
        layout_sub.addWidget(self.canvas2)
        layout_sub.addWidget(self.label1)
        layout_sub.addWidget(self.lcd3)

        # demo
        final_demo = QHBoxLayout()
        final_demo.addLayout(obj_buttons)
        final_demo.addSpacing(30)
        final_demo.addLayout(layout_main)
        final_demo.addLayout(layout_sub)
        self.setLayout(final_demo)

        # timer used for playback (see on_play / on_time)
        self.timer = QTimer()
        self.timer.setSingleShot(False)
        self.timer.timeout.connect(self.on_time)

        # initialize visualize
        self.current_mask = np.zeros((self.num_frames, self.height, self.width), dtype=np.uint8)
        self.cursur = 0
        self.on_showing = None
        self.on_showing2 = None
        self.show_current()

        # initialize action
        self.reset_scribbles()
        self.pressed = False
        self.on_drawing = None
        self.drawn_strokes = []
        self.obj1_button.setChecked(True)
        self.show()

    # ------------------------------------------------------------------ helpers

    def _make_object_button(self, obj_id, color, handler):
        """Create the checkable, colour-coded button that selects object ``obj_id``."""
        button = QPushButton('\nAnnotate \nobject {} [{}]\n'.format(obj_id, obj_id))
        button.clicked.connect(handler)
        button.setMaximumHeight(80)
        button.setStyleSheet("background-color: {}".format(color))
        button.setCheckable(True)
        # Object 1 always gets its shortcut; others only if the video has them.
        if obj_id == 1 or self.n_obj >= obj_id:
            button.setShortcut(str(obj_id))
        return button

    def _object_buttons(self):
        """Return the five object-selection buttons in id order (1..5)."""
        return [self.obj1_button, self.obj2_button, self.obj3_button,
                self.obj4_button, self.obj5_button]

    def _select_object(self, obj_id):
        """Make ``obj_id`` the object that left-click scribbles are assigned to.

        The clicked button has already flipped its own check state when this
        runs, so every early exit restores a consistent check state.
        """
        buttons = self._object_buttons()
        clicked = buttons[obj_id - 1]
        if self.pressed:
            # A stroke is in progress: undo the click's automatic toggle.
            clicked.toggle()
            return
        if obj_id > 1 and self.n_obj < obj_id:
            # Object not present in this video: keep its button unchecked.
            clicked.setChecked(False)
            return
        self.current_object = obj_id
        for idx, button in enumerate(buttons, start=1):
            button.setChecked(idx == obj_id)

    def _drawing_allowed(self):
        """True when scribbling is permitted.

        Drawing is blocked while a candidate frame still has to be picked after
        a VOS run (one finding timestamp is recorded per operate timestamp).
        """
        return (len(self.finding_timesteps) == len(self.operate_timesteps)
                and self.after_candidates_decided)

    def _candidate_cells(self):
        """Top-left corners of the four quadrants of a frame-sized 2x2 grid (A, B, C, D)."""
        half_w, half_h = self.width // 2, self.height // 2
        return [(0, 0), (half_w, 0), (0, half_h), (half_w, half_h)]

    def _paste_candidate_tiles(self, canvas, source_frames):
        """Paste each candidate frame of ``source_frames`` into its quadrant of ``canvas``.

        Quadrants without a candidate (fewer than four found) stay black.
        """
        tile_size = (self.width // 2 - 4, self.height // 2 - 2)
        for (cx, cy), frame_idx in zip(self._candidate_cells(), self.candidate_frames):
            tile = Image.fromarray(source_frames[frame_idx]).resize(tile_size)
            canvas.paste(tile, (cx + 2, cy + 1))

    def _display(self, main_image):
        """Show ``main_image`` on the main canvas and the current frame's GT on the sub canvas."""
        if self.on_showing:
            self.on_showing.remove()
            self.on_showing2.remove()
        self.on_showing = self.ax1.imshow(main_image)
        self.on_showing2 = self.ax2.imshow(self.gts_overlayed[self.cursur])
        self.canvas1.draw()
        self.canvas2.draw()
        self.lcd1.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames - 1))
        self.slider.setValue(self.cursur)

    def _select_candidate(self, cand_idx):
        """Jump to suggested candidate ``cand_idx`` (0=A .. 3=D) and unlock drawing.

        A pick is accepted only once per round, and only for an existing
        candidate. A repeated pick would record an extra finding timestamp, and
        the timestamp lists would then stay unequal, blocking all further drawing.
        """
        if self.after_candidates_decided or cand_idx >= len(self.candidate_frames):
            return
        self.finding_timesteps.append(time.time() - self.time_init)
        self.cursur = int(self.candidate_frames[cand_idx])
        self.after_candidates_decided = True
        self.show_current()

    # ------------------------------------------------------------------ actions

    def restart_video(self):
        """Re-initialize the whole demo on the same video with the same options."""
        # Stop playback so the old timer does not keep firing during re-init.
        self.timer.stop()
        # NOTE(review): re-running __init__ on a live QWidget also calls setLayout
        # on a widget that already has a layout; Qt ignores the second one. Confirm
        # the restarted UI is laid out as intended.
        self.__init__(self.DIE, self.model, self.root, self.video_indices,
                      save_imgs=self.save_imgs)

    def show_candidates(self):
        """Pick up to four spread-out, low-score frames and show them as a 2x2 grid.

        Frames are visited in ascending ``model.scores_nf`` order. Each chosen
        frame excludes a window of about ``num_frames / 10`` frames around it.
        Selected frames are sorted and stored in ``self.candidate_frames``.
        """
        sorted_score_idx = np.argsort(self.model.scores_nf)
        exclude_range = self.num_frames / 10
        excluded = set()
        candidates = []
        for frame_idx in sorted_score_idx[:self.num_frames]:
            if frame_idx in excluded:
                continue
            candidates.append(frame_idx)
            excluded.update(range(int(frame_idx - (exclude_range / 2) + 0.5),
                                  int(frame_idx + (exclude_range / 2) + 0.5)))
            if len(candidates) == 4:
                break
        self.candidate_frames = sorted(candidates)

        canvas = Image.new('RGB', (self.width, self.height))
        self._paste_candidate_tiles(canvas, self.vis_frames)
        # Labels are drawn after all tiles so a long label is never covered by a neighbouring tile.
        draw = ImageDraw.Draw(canvas)
        for label, (cx, cy), frame_idx in zip('ABCD', self._candidate_cells(), self.candidate_frames):
            draw.multiline_text((cx + 5, cy + 5), 'Candidate {}: Fr{:03d}'.format(label, frame_idx),
                                fill=(255, 255, 255, 255), font=self.selected_font, spacing=1.5, align="right")
        vis_candidates = np.array(canvas)
        if self.on_showing:
            self.on_showing.remove()
            self.on_showing2.remove()
        self.on_showing = self.ax1.imshow(vis_candidates)
        self.canvas1.draw()

    def show_candidates_gt(self):
        """Show the ground-truth overlays of the current candidate frames as a 2x2 grid.

        Expected to run right after show_candidates, which already removed the
        previous sub-canvas image.
        """
        canvas = Image.new('RGB', (self.width, self.height))
        self._paste_candidate_tiles(canvas, self.gts_overlayed)
        self.on_showing2 = self.ax2.imshow(np.array(canvas))
        self.canvas2.draw()

    def show_current(self):
        """Display the current frame with the latest propagated masks overlaid."""
        self._display(self.vis_frames[self.cursur])

    def show_current_anno(self):
        """Display the current frame overlaid with the mask from the latest interaction."""
        viz = overlay_mask(self.frames[self.cursur], self.current_mask[self.cursur], alpha=0.5, contour_thickness=2)
        self._display(viz)

    def reset_scribbles(self):
        """Discard all scribbles collected so far (one empty list per frame)."""
        self.scribbles = {}
        self.scribbles['scribbles'] = [[] for _ in range(self.num_frames)]
        self.scribbles['sequence'] = self.sequence

    def clear_strokes(self):
        """Remove drawn stroke lines from the main axes and redraw if any existed."""
        if self.drawn_strokes:
            for line in self.drawn_strokes:
                if line is not None:
                    line.pop(0).remove()
            self.drawn_strokes = []
            self.canvas1.draw()
            self.canvas2.draw()

    def slide(self):
        """Slider moved: drop pending scribbles and show the selected frame."""
        self.clear_strokes()
        self.reset_scribbles()
        self.cursur = self.slider.value()
        self.show_current()

    def on_candidateA(self):
        self._select_candidate(0)

    def on_candidateB(self):
        self._select_candidate(1)

    def on_candidateC(self):
        self._select_candidate(2)

    def on_candidateD(self):
        self._select_candidate(3)

    def on_run_dschange(self):
        """Log 'Running VOS...' when the run button is pressed and scribbles exist."""
        if len(self.scribbles['scribbles'][self.cursur]) >= 1:
            self.text_print += 'Running VOS...\n'
            self.lcd3.setText(self.text_print)

    def on_run(self):
        """Propagate masks from the scribbled frame and start the next round.

        Does nothing unless the current frame has at least one scribble.
        """
        if len(self.scribbles['scribbles'][self.cursur]) < 1:
            return
        self.scribble_timesteps.append(time.time()-self.time_init)
        self.VOS_once_executed_bool = True
        self.model.Run_propagation(self.cursur)
        self.current_mask = self.model.Get_mask()

        self.current_round += 1

        print('[Overlaying segmentations...]')
        for fr in range(self.num_frames):
            self.vis_frames[fr] = overlay_mask(self.frames[fr], self.current_mask[fr], alpha=0.5, contour_thickness=2)
        print('[Overlaying Done.] \n')

        # show candidate frames, then clear scribbles and reset
        self.show_candidates()
        self.show_candidates_gt()
        self.after_candidates_decided = False
        self.reset_scribbles()
        self.clear_strokes()
        self.lcd2.setText('Current round : {:02d}'.format(self.current_round + 1))
        self.text_print += '\nRound [{:02d}]\n'.format(self.current_round+1)

        self.operate_timesteps.append(time.time() - self.time_init)
        self.slider.setDisabled(True)
        self.text_print += 'Finding a unsatisfying frame...\n'
        self.lcd3.setText(self.text_print)

    def on_end(self):
        """Finish the video: score the final masks, log results, optionally save PNGs.

        Only acts after at least one VOS run and when the current frame has no
        pending (un-propagated) scribbles. Quits the Qt application when a
        specific video index was requested.
        """
        if not (self.VOS_once_executed_bool and len(self.scribbles['scribbles'][self.cursur]) == 0):
            return
        # Close the last round's "finding" time if no candidate was picked yet.
        if len(self.finding_timesteps) == (len(self.operate_timesteps)-1):
            self.finding_timesteps.append(time.time()-self.time_init)
        final_mask = self.model.Get_mask()
        final_J = np.average(batched_jaccard(self.gts, final_mask, average_over_objects=False), axis=0)  # n_obj
        final_F = np.average(batched_f_measure(self.gts, final_mask, average_over_objects=False), axis=0)  # n_obj

        self.DIE.write_in_csv(self.sequence, self.n_obj, final_J, final_F, self.scribble_timesteps, self.operate_timesteps, self.finding_timesteps)

        if self.save_imgs:
            save_dir = os.path.join('result_video', 'Alg[{}]_{}'.format(self.DIE.algorithm_name, self.DIE.current_time), '{}'.format(self.sequence))
            utils_custom.mkdir(save_dir)
            for fr_idx in range(self.num_frames):
                savefname = os.path.join(save_dir, '{:05d}.png'.format(fr_idx))
                tmpPIL = Image.fromarray(final_mask[fr_idx].astype(np.uint8), 'P')
                tmpPIL.putpalette(self._palette)
                tmpPIL.save(savefname)

        if self.video_indices is not None:
            QCoreApplication.instance().quit()

    def on_prev(self):
        """Go to the previous frame (clamped at 0), dropping pending scribbles."""
        self.clear_strokes()
        self.reset_scribbles()
        self.cursur = max(0, self.cursur-1)
        self.show_current()

    def on_next(self):
        """Go to the next frame (clamped at the last one), dropping pending scribbles."""
        self.clear_strokes()
        self.reset_scribbles()
        self.cursur = min(self.cursur+1, self.num_frames-1)
        self.show_current()

    def on_time(self):
        """Playback tick: advance one frame, wrapping to frame 0 after the last."""
        self.clear_strokes()
        self.reset_scribbles()
        self.cursur += 1
        if self.cursur > self.num_frames-1:
            self.cursur = 0
        self.show_current()

    def on_play(self):
        """Toggle playback; the timer advances one frame every 10 ms."""
        if self.timer.isActive():
            self.timer.stop()
        else:
            # QTimer.start takes an integer millisecond interval.
            self.timer.start(10)

    def on_press(self, event):
        """Start a new stroke at the click position (left button = current object, else background)."""
        if not self._drawing_allowed():
            return
        self.slider.setDisabled(True)
        if self.not_started:
            self.text_print += 'Providing scribble...\n'
            self.lcd3.setText(self.text_print)
            self.time_init = time.time()
            self.not_started = False

        # xdata/ydata are None when the click lands outside the image axes.
        if event.xdata is None or event.ydata is None:
            return
        self.pressed = True
        # Path points are stored normalized to [0, 1] by image width/height.
        self.stroke = {'path': [[event.xdata/self.width, event.ydata/self.height]]}
        self.stroke['object_id'] = self.current_object if event.button == 1 else 0
        self.stroke['start_time'] = time.time()
        self.visualize_annotation(event)

    def on_motion(self, event):
        """Extend the stroke being drawn (no-op unless a stroke is in progress)."""
        if self._drawing_allowed():
            self.visualize_annotation(event)

    def on_release(self, event):
        """Finish the current stroke and run the model's interaction on this frame.

        Ignored when no stroke was started (e.g. the press was outside the axes),
        so a stale stroke is never re-submitted.
        """
        if not self._drawing_allowed() or not self.pressed:
            return
        self.pressed = False
        if event.xdata is not None and event.ydata is not None:
            self.stroke['path'].append([event.xdata/self.width, event.ydata/self.height])
        self.stroke['end_time'] = time.time()
        self.scribbles['annotated_frame'] = self.cursur
        self.scribbles['scribbles'][self.cursur].append(self.stroke)
        self.drawn_strokes.append(self.on_drawing)
        self.on_drawing = None

        self.model.Run_interaction(self.scribbles)
        self.current_mask[self.cursur] = self.model.Get_mask_index(self.cursur)
        self.show_current_anno()

    def visualize_annotation(self, event):
        """Append the event position to the active stroke and redraw its polyline."""
        if not (self.pressed and event.xdata is not None and event.ydata is not None):
            return
        self.stroke['path'].append([event.xdata/self.width, event.ydata/self.height])

        x = [p[0]*self.width for p in self.stroke['path']]
        y = [p[1]*self.height for p in self.stroke['path']]
        if self.on_drawing:
            self.on_drawing.pop(0).remove()

        if self.stroke['object_id'] == 0:
            # background strokes are drawn in black
            self.on_drawing = self.ax1.plot(x, y, marker='o', markersize=4, linewidth=5, color=[0, 0, 0])
        elif self.stroke['object_id'] == self.current_object:
            # lightened palette colour of the current object
            self.on_drawing = self.ax1.plot(x, y, marker='o', markersize=4, linewidth=5, color=(self.cmap[self.current_object])/320 + 0.2)
        self.canvas1.draw()

    def obj1_pressed(self):
        self._select_object(1)

    def obj2_pressed(self):
        self._select_object(2)

    def obj3_pressed(self):
        self._select_object(3)

    def obj4_pressed(self):
        self._select_object(4)

    def obj5_pressed(self):
        self._select_object(5)

    def keyPressEvent(self, event):
        """Close the window on Escape; forward every other key to QWidget."""
        if event.key() == Qt.Key_Escape:
            self.close()
        else:
            super().keyPressEvent(event)